dlbs: U-Nets in a Self-Driving Segmentation Task
- Name: Nils Fahrni
- Submission Date: 17.01.2025
- dlbs: U-Nets in a Self-Driving Segmentation Task
- Project Specifications & Introduction
- Dataset
- Exploration
- Training and Evaluation Skeleton
- Overfit
- Regularization
- Tuning the model
- Squeeze The Juice
- Implementing Attention into the U-Net
- Training the Attention U-Net
- Experiment 1: Attention U-Net with dropout_prob=0.1
- Experiment 2: Attention U-Net with dropout_prob=0.2
- Experiment 3: Attention U-Net with dropout_prob=0.3
- Summarizing the Attention Experiments
- Taking A Look At Saliency Attention Maps
- Looking At The Predicted Segmentations
- Global Summary
Project Specifications & Introduction
Research Question: "How do segmentation models perform between scenes of city streets and non-city streets in the BDD100K dataset?"
My motivation for this research question and domain stems from the final summer break of my apprenticeship, in the summer of 2019. I went to San Francisco and experienced Waymo for the first time, and I was fascinated by the technology behind autonomous driving. On the last day, on the way back to the airport, my Uber driver even had a Comma.ai device installed in his old Honda Civic, a device I later learned is developed open source. When I got home, I started digging deeper into the technology and decided I wanted to one day develop such systems myself. This was when I applied for the Data Science program at FHNW.
I still pursue that goal today. Standing before the last semester of this undergraduate program, and having completed this Deep Learning course in the context of computer vision, I can confidently say I have learned the fundamentals needed to move closer to building and contributing to autonomous systems. I have applied to numerous master's programs in the field of Machine Learning at universities known for their research in autonomous systems.
This project was a great stepping stone: I finally got to work with a dataset that is used in industry and has various real-world applications.
import random
import numpy as np
RANDOM_SEED = 1337
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
Note: According to Karpathy's recipe, setting the random seed usually only happens in the second step. However, I set it here already to ensure the following exploration is reproducible, since I sample some data.
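Since later steps also rely on PyTorch's RNG (weight initialization, data shuffling), the seeding could be extended to torch as well. A small sketch; the seed_everything helper is my own illustration, not part of the project code:

```python
import random

import numpy as np
import torch

RANDOM_SEED = 1337

def seed_everything(seed: int) -> None:
    # Seed every RNG the pipeline touches so that sampling, shuffling
    # and weight initialization become repeatable across runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # covers CPU and CUDA generators

seed_everything(RANDOM_SEED)
```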
Dataset
The dataset used in this challenge is the Berkeley Deep Drive dataset, described at https://arxiv.org/abs/1805.04687.
It comprises 10,000 images that are pre-partitioned into train, val and test subsets. Only the train and val partitions have corresponding segmentation masks (ground truth); the test partition does not, since the dataset was published as a challenge and the ground truth is held out for evaluating submitted models.
The research question in this challenge requires the data to have a scene label. In the label_matching.ipynb notebook I first explored whether every image contains a scene label. It turned out that only the pre-partitioned test subset has scene labels, and within that partition only 3426 sample pairs do. I additionally performed some more thorough analysis and statistical testing (see the linked notebook) on the matched scene labels and concluded that I will primarily be able to make performance comparisons between city streets and non-city streets, the latter mainly containing residential and highway scenes.
import os
from data import BDD100KDataset
BASE_PATH = os.path.join('data', 'bdd100k')
dataset = BDD100KDataset(base_path=BASE_PATH)
from tqdm import tqdm
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap, BoundaryNorm
import seaborn as sns
Getting an Overview
The first step is to look at some samples to get a grasp of how the images and masks in the dataset are structured.
Here I will first define a global colormap lookup for the future plots so that all classes are displayed consistently throughout this notebook.
from data.resolved_names import class_dict
colors = matplotlib.colormaps['tab20'](np.linspace(0, 1, 19)) # for 19 classes (255 excluded)
norm = BoundaryNorm(np.arange(-0.5, 19.5, 1), 19) # to create an exact boundary between the classes
cmap = ListedColormap(colors)
full_colors = colors
class_colors = {class_name: full_colors[i][:3] for i, class_name in enumerate(class_dict)}
class_colors_hex = {class_name: matplotlib.colors.to_hex(color) for class_name, color in class_colors.items()}
The plot_images_and_masks function plots num_samples (in this case 5) rows of images and their corresponding segmentation masks that will be used for training. The dataset consists of 19 classes; a 20th class with id 255 describes the background or unlabeled category, i.e. objects that do not get their own label. In the following plot, class 255 is masked out and shows up as negative white space.
def plot_images_and_masks(samples, class_colors=None, figsize=(12, 20)):
"""
Plots side-by-side views of images and their corresponding segmentation masks.
Args:
samples (Dataset): An instance of the BDD100KDataset.
num_samples (int): Number of samples to visualize.
class_colors (dict): Optional dictionary mapping class names to RGB colors.
"""
fig, axes = plt.subplots(len(samples), 2, figsize=figsize)
fig.suptitle('BDD100K Dataset Samples', fontsize=16, y=1)
axes = np.atleast_2d(axes)
for i, (image, label, scene) in enumerate(samples):
mask_cleaned = np.where(label == 255, np.nan, label)
ax_img, ax_mask = axes[i]
ax_img.imshow(image.permute(1, 2, 0).numpy())
ax_img.set_title("Image")
ax_img.text(25, 25, scene, fontsize=10, color='white', bbox=dict(facecolor='black', alpha=0.8, pad=2), va='top', ha='left')
ax_img.axis('off')
ax_mask.imshow(mask_cleaned, cmap=cmap, norm=norm, interpolation='none')
ax_mask.set_title("Segmentation Mask")
ax_mask.axis('off')
if class_colors:
legend_patches = [
mpatches.Patch(color=color, label=class_dict[class_name])
for class_name, color in class_colors.items()
]
ax_mask.legend(
handles=legend_patches,
loc='right',
fontsize='x-small'
)
plt.tight_layout()
plt.subplots_adjust()
plt.show()
indices = np.random.choice(len(dataset), 5, replace=False)
plot_images_and_masks(dataset[indices], class_colors=class_colors_hex)
The samples show some of the nuances of the dataset and how it is built up:
- All samples show the cockpit of the recording POV car, mostly at the edges of the images, and it gets assigned class 255. This seems to be a design choice, since one could also interpret the cockpit as belonging to the class car.
- The first image also shows a dumpster on the right sidewalk, a case of an unlabeled class (255).
- In the highway scenes the median barriers also seem to be unlabeled, something I would have expected to get the label wall.
- Not all images have the same granularity and detail when it comes to the labels. For example, in the third image the fence on the left side of the street gets the label fence, but the fence structures on the right side do not and are instead counted as building. The same can be observed in the fourth image, where some town houses on the left are unlabeled rather than assigned the building class.
- The scene label also does not always seem clear-cut. For example, the last image shows a tunnel (a possible scene attribute) but the image's actual scene is highway. This could be because highway ranks higher in the scene hierarchy, or because the tunnel is some sort of underpass.
Classes and their Location
This exploration step aims to show how the images are structured spatially. So far we have only looked at a few images directly. The function plot_class_heatmaps, however, takes a larger sample (in this case 1000 images, capped for computational reasons) and plots 19 heatmaps (one per class) describing where each class usually occurs in the images.
def plot_class_heatmaps(dataset, n_samples):
"""
Generate and plot spatial heatmaps for all classes in the dataset.
Args:
dataset: Dataset object with data samples containing label masks.
n_samples: Number of samples to process from the dataset in total.
"""
heatmaps = {class_id: np.zeros((dataset[0][0].shape[1], dataset[0][0].shape[2]), dtype=np.float32) for class_id in class_dict.keys()}
class_sample_counts = {class_id: 0 for class_id in class_dict.keys()}
indices = np.random.permutation(len(dataset))[:n_samples]
for idx in tqdm(indices, desc="Processing Dataset"):
_, label, _ = dataset[idx]
label_np = label.numpy()
for class_id in class_dict.keys():
mask = (label_np == class_id)
heatmaps[class_id] += mask.astype(np.float32)
class_sample_counts[class_id] += mask.sum()
for class_id in heatmaps:
if class_sample_counts[class_id] > 0:
heatmaps[class_id] /= class_sample_counts[class_id]
fig, axs = plt.subplots(4, 5, figsize=(20, 15))
fig.suptitle("Spatial Heatmaps for all Classes", fontsize=20)
for class_id, class_name in class_dict.items():
ax = axs[class_id // 5, class_id % 5]
sns.heatmap(heatmaps[class_id], ax=ax, cmap="viridis", cbar=False)
ax.set_title(class_name)
ax.axis('off')
for i in range(len(class_dict), 4 * 5):
fig.delaxes(axs[i // 5, i % 5])
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
plot_class_heatmaps(dataset, n_samples=1000)
Processing Dataset: 100%|██████████| 1000/1000 [01:22<00:00, 12.19it/s]
- The smoothest heatmaps belong to the classes sky, car, road, vegetation and building. I hypothesize that this is due to the high frequency of these classes. They all appear at plausible locations: the road heatmap shows the road always in the bottom half of the image, centered, and likewise the sky is always at the top, also centered.
- In the less populated heatmaps we can clearly see the shapes of individual objects. For example, in rider (which is meant to represent a bike/bicycle rider) clear shapes of people are visible.
- These sparser heatmaps also reveal some flaws of the dataset; not everything seems to be labelled correctly, though the majority is. In rider we can for example see the back of a car, and in motorcycle a car is also visible on the right side. In bicycle we can see the shape of vegetation with street lanterns (pole) in the negative space.
- The train class probably has the fewest occurrences; it shows the smallest number of shapes.
Class Occurrence
Now that we have looked at the spatial distribution of each class and made some assumptions based on 1000 samples, let's look at the actual frequency of each class in the dataset. This gives insight into how often each class occurs (i.e. which classes a model might struggle to classify).
def compute_co_occurrence_matrix(dataset):
"""
Compute and plot a co-occurrence matrix for all classes in the dataset.
Args:
dataset: Dataset object with data samples containing label masks.
"""
global class_dict
num_classes = len(class_dict)
class_names = list(class_dict.values())
co_occurrence_matrix = np.zeros((num_classes, num_classes), dtype=np.int32)
for idx in tqdm(range(len(dataset)), desc="Processing Dataset for Co-Occurrence Matrix"):
_, label, _ = dataset[idx]
label_np = label.numpy()
unique_classes = np.unique(label_np)
for i in range(len(unique_classes)):
for j in range(i, len(unique_classes)):
class_i = unique_classes[i]
class_j = unique_classes[j]
if class_i < num_classes and class_j < num_classes:
co_occurrence_matrix[class_i, class_j] += 1
if class_i != class_j:
co_occurrence_matrix[class_j, class_i] += 1
plt.figure(figsize=(12, 10))
ax = sns.heatmap(co_occurrence_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.title("Class Co-Occurrence Matrix for Dataset", pad=20)
plt.xlabel("Class")
plt.ylabel("Class")
ax.xaxis.tick_top()
ax.xaxis.set_label_position('top')
plt.xticks(rotation=45, ha="left")
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
compute_co_occurrence_matrix(dataset)
Processing Dataset for Co-Occurrence Matrix: 100%|██████████| 3426/3426 [01:31<00:00, 37.39it/s]
The co-occurrence matrix shows in how many images each class occurs throughout the dataset (diagonal values), as well as in how many images a class co-occurs with each other class.
- Cars (car) are the most frequent class in the dataset, with 3361 images containing at least one car.
- Closely after car, road seems to be the second most frequent class; unexpectedly, it does not occur in every image. Out of 3426 samples, 145 images are missing the road class, which is surprising.
- Other frequent classes are pole (from traffic lights, street lanterns and street signs), sidewalk (most likely found in the residential scenes), sky and building. The sky label is also not in every image, but that is an observation we already made in one of the first plots, as images can be recorded in a tunnel or have the sky obstructed by a building or vegetation.
- The suspected candidates for rarer classes, train, rider, motorcycle and bicycle, are confirmed to have a low number of samples in my subset.
- Highly co-occurring pairs show up as darker square areas in the matrix. One such area sits in the top left corner for road, sidewalk and building; these have generally high occurrence, so they naturally overlap. Also pole and traffic sign occur together often, since traffic signs are mostly attached to poles.
Samples without Road
As the co-occurrence matrix has shown that there are samples without any road, I take a closer look at those samples to investigate whether this is an error on the dataset side or an error in my co-occurrence matrix construction.
samples_without_road = []
for i, (image, seg, scene) in enumerate(dataset):
if 0 not in torch.unique(seg):
samples_without_road.append(i)
indices = np.random.choice(samples_without_road, 5, replace=False)
plot_images_and_masks(dataset[indices], class_colors=class_colors_hex)
The road is indeed missing on the dataset side. The pictures clearly show road, but the segmentation masks misclassify it or leave it unlabeled.
I will not filter out these "malformed" samples, as they also contain correctly assigned labels, especially for cars, buildings and walls. Road is a disproportionately frequent class, so these 145 samples will most likely not skew the model's ability to classify road, but rather contribute a regularizing effect.
Image Sizes
As the last step of exploration, I look at how much I can resize the images without losing too much information. This is important for training performance, both time- and memory-wise, though it is important not to resize too harshly, as that could lead to information loss, something we want to avoid in self-driving.
from torchvision.transforms import Resize, ToPILImage
def create_resolution_grid(samples):
"""
Creates a grid showing images at various resolutions.
Args:
samples (list): A list of tuples from the dataset (image, label, scene).
"""
resolutions = [None, 256, 128, 64, 32]
resize_transforms = [None] + [Resize((res, res)) for res in resolutions[1:]]
to_pil = ToPILImage()
num_samples = len(samples)
num_resolutions = len(resolutions)
fig, axes = plt.subplots(num_samples, num_resolutions, figsize=(num_resolutions * 3, num_samples * 3))
fig.suptitle('BDD100K in Different Resolutions', fontsize=16, y=1.02)
for i, (image, _, _) in enumerate(samples):
for j, resize_transform in enumerate(resize_transforms):
if resize_transform:
resized_image = resize_transform(image)
else:
resized_image = image
resized_image = to_pil(resized_image)
ax = axes[i, j]
ax.imshow(resized_image)
ax.axis('off')
if i == 0:
if resolutions[j]:
ax.set_title(f"{resolutions[j]}x{resolutions[j]} pixels", fontsize=10)
else:
ax.set_title("Original", fontsize=10)
plt.tight_layout()
plt.show()
indices = np.random.choice(len(dataset), 5, replace=False)  # sample from the whole dataset
create_resolution_grid(dataset[indices])
The explored resolutions show that between 128x128 and 64x64 pixels there is quite a large information loss. Distant objects can no longer be distinguished at 64x64 but are still acceptably visible at 128x128. 128x128 therefore seems to be the best trade-off between information loss and memory usage, so I will use that resolution to train my models.
Reducing the Number of Classes
Prior to the actual training, I had already tried training the models on all 19 classes. This showed that the models are not able to learn rarely occurring classes. I suspect this was due to the small sample size of the subset I am using, which overlaps with certain scene attributes. In this step, however, I do not just look at the bare frequency of each class but rather at the total pixel count. This is a design decision to make the model more robust on the larger class objects: since I am reducing the resolution of the input images, classes with a low pixel count (such as poles, which are very thin but occur frequently in the dataset) would lose a lot of information anyway.
The following function will calculate the pixel count for each class in the dataset and then return the top_n (in this case 5) classes.
import torch
from tqdm import tqdm
from data.resolved_names import class_dict
def calc_top_n_pixels(dataset, top_n=5, n_classes=19):
"""
Calculates the total number of pixels for each class in the dataset and identifies the top_n classes with the most pixels.
Args:
dataset (Dataset): The dataset to analyze.
top_n (int, optional): Number of top classes to display based on pixel count. Defaults to 5.
n_classes (int, optional): Total number of classes. Defaults to 19.
Returns:
list of tuples: A list containing tuples of (class_id, class_name, pixel_count) for the top_n classes.
"""
class_counts = torch.zeros(n_classes, dtype=torch.long)
print("Counting pixels for each class...")
for idx in tqdm(range(len(dataset)), desc="Processing Images"):
_, label, _ = dataset[idx]
if not isinstance(label, torch.Tensor):
raise TypeError(f"Expected label to be a torch.Tensor, but got {type(label)}")
label = label.long()  # ensure an integer dtype for bincount
label_flat = label.view(-1)
mask = (label_flat < n_classes) & (label_flat >= 0)
label_valid = label_flat[mask]
if label_valid.numel() == 0:
continue
counts = torch.bincount(label_valid, minlength=n_classes)
class_counts += counts
top_counts, top_ids = torch.topk(class_counts, top_n)
top_classes = []
print(f"\nTop {top_n} Classes with the Most Pixels:")
print("--------------------------------------------------")
print(f"{'Rank':<5} {'Class ID':<10} {'Class Name':<20} {'Pixel Count':<15}")
print("--------------------------------------------------")
for rank, (cls_id, count) in enumerate(zip(top_ids.tolist(), top_counts.tolist()), start=1):
class_name = class_dict.get(cls_id, "Unknown")
top_classes.append((cls_id, class_name, count))
print(f"{rank:<5} {cls_id:<10} {class_name:<20} {count:<15}")
print("--------------------------------------------------")
return top_classes
calc_top_n_pixels(dataset, top_n=5)
Counting pixels for each class...
Processing Images: 100%|██████████| 3426/3426 [01:00<00:00, 56.63it/s]
Top 5 Classes with the Most Pixels:
--------------------------------------------------
Rank  Class ID   Class Name           Pixel Count
--------------------------------------------------
1     0          road                 668329989
2     2          building             452690852
3     10         sky                  418227625
4     8          vegetation           384303219
5     13         car                  254085637
--------------------------------------------------
[(0, 'road', 668329989), (2, 'building', 452690852), (10, 'sky', 418227625), (8, 'vegetation', 384303219), (13, 'car', 254085637)]
The top 5 classes are:
- road
- building
- sky
- vegetation
- car
To train on these classes only, I created a remapping transformation, found in data/utils.py as RemapClasses. It remaps all classes that are not in the top 5 to the unlabeled class 255 and is applied to the dataset before training. Since the top 5 classes are not numbered contiguously, RemapClasses renumbers them according to the dictionary defined in the Hydra configuration configs/dataset/bdd100k.yaml.
During training, the dataset will be initialized with the RemapClasses transformation and the model will be trained on the remapped dataset.
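To illustrate the idea, here is a minimal sketch of what such a target transform could look like. The mapping shown is only derived from the top-5 result above for illustration; the actual implementation and mapping live in data/utils.py and configs/dataset/bdd100k.yaml and may differ.

```python
import torch

# Illustrative mapping: the five kept classes get contiguous ids 0-4,
# everything else falls back to the ignore class 255.
REMAP = {0: 0, 2: 1, 10: 2, 8: 3, 13: 4}  # road, building, sky, vegetation, car

class RemapClasses:
    """Target transform that remaps a label mask to a reduced class set."""

    def __init__(self, mapping: dict, ignore_index: int = 255):
        self.mapping = mapping
        self.ignore_index = ignore_index

    def __call__(self, mask: torch.Tensor) -> torch.Tensor:
        # Start from an all-ignore mask and copy over the kept classes.
        remapped = torch.full_like(mask, self.ignore_index)
        for old_id, new_id in self.mapping.items():
            remapped[mask == old_id] = new_id
        return remapped
```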
Partitioning the Dataset & Transforming
Before starting to train models, I split the dataset into training, validation and test partitions. This normally happens when calling the train.py script with Hydra. In this notebook I do the same, first of all to find normalization parameters, since these should only be calculated on the training set (to prevent leakage). Later on I will also use these partitions to inspect some segmentation outputs by passing samples through the model.
To make sure that I imitate the same partitioning and transformation steps as in the training script, I import the Hydra configuration and apply the same parameters to the dataset and the splitting function. It also gets the same random_seed so the partitioning is identical.
import hydra
from hydra import compose, initialize
from data.utils import split_dataset
with initialize(version_base=None, config_path="./configs"):
cfg = compose(config_name="config")
train_dataset, val_dataset, test_dataset = split_dataset(dataset,
train_ratio=cfg.dataset.train_ratio,
val_ratio=cfg.dataset.val_ratio,
test_ratio=cfg.dataset.test_ratio,
random_seed=cfg.seed)
import torch
from tqdm import tqdm
def calculate_rgb_metrics(dataset):
"""
Calculate the mean and standard deviation of the RGB channels for a dataset.
Args:
dataset (Dataset): The dataset to calculate metrics for.
Returns:
tuple: A tuple containing two lists:
- mean (list): Mean values for R, G, B channels.
- std (list): Standard deviation values for R, G, B channels.
"""
channel_sum = torch.zeros(3)
channel_sum_squared = torch.zeros(3)
total_pixels = 0
print("Calculating mean and standard deviation...")
for idx in tqdm(range(len(dataset)), desc="Processing Images"):
image, _, _ = dataset[idx]
if not isinstance(image, torch.Tensor):
raise TypeError(f"Expected image to be a torch.Tensor, but got {type(image)}")
if image.dtype != torch.float32:
image = image.float()
if image.max() > 1.0:
image = image / 255.0
_, height, width = image.shape
num_pixels = height * width
total_pixels += num_pixels
channel_sum += image.sum(dim=[1, 2])
channel_sum_squared += (image ** 2).sum(dim=[1, 2])
mean = channel_sum / total_pixels
std = (channel_sum_squared / total_pixels - mean ** 2) ** 0.5
mean_list = [round(m.item(), 4) for m in mean]
std_list = [round(s.item(), 4) for s in std]
print("\nRGB Channel Metrics:")
print(f"Mean: {mean_list}")
print(f"Std: {std_list}")
return mean_list, std_list
calculate_rgb_metrics(train_dataset)
Calculating mean and standard deviation...
Processing Images: 100%|██████████| 2398/2398 [00:40<00:00, 58.99it/s]
RGB Channel Metrics:
Mean: [0.3649, 0.3997, 0.4047]
Std: [0.2528, 0.2645, 0.2754]
([0.3649, 0.3997, 0.4047], [0.2528, 0.2645, 0.2754])
The resulting metrics are then copied over to the dataset config in configs/dataset/bdd100k.yaml to make sure that the training script uses the same parameters.
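For illustration, the per-channel operation that normalization performs with these statistics looks as follows (a sketch of what torchvision's Normalize does with these values, not the project's actual transform code):

```python
import torch

# Training-set statistics computed above, reshaped so they broadcast
# over a (3, H, W) image tensor.
mean = torch.tensor([0.3649, 0.3997, 0.4047]).view(3, 1, 1)
std = torch.tensor([0.2528, 0.2645, 0.2754]).view(3, 1, 1)

def normalize(image: torch.Tensor) -> torch.Tensor:
    """Standardize a float image in [0, 1] channel-wise."""
    return (image - mean) / std
```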
transformed_dataset = BDD100KDataset(base_path='./data/bdd100k',
transform=hydra.utils.instantiate(cfg.dataset.transform),
target_transform=hydra.utils.instantiate(cfg.dataset.target_transform))
transformed_train_dataset, transformed_val_dataset, transformed_test_dataset = split_dataset(transformed_dataset,
train_ratio=cfg.dataset.train_ratio,
val_ratio=cfg.dataset.val_ratio,
test_ratio=cfg.dataset.test_ratio,
random_seed=cfg.seed)
To ensure the splitting works as intended and does not leak any samples between the partitions, I am checking if there is an unintended overlap.
from data.utils import check_dataset_overlap
check_dataset_overlap(transformed_train_dataset, transformed_test_dataset, transformed_val_dataset)
--- Overlap Report ---
✔️ No overlap detected between train and validation sets.
✔️ No overlap detected between train and test sets.
✔️ No overlap detected between validation and test sets.
The output shows that the code works as expected; no samples leak into another partition.
Ensuring Equal Distribution
Lastly I want to check if my partitioning has resulted in an equal distribution of the labels.
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
from tqdm import trange
def map_class_names_and_order(class_distribution, class_dict):
ordered_classes = sorted(class_dict.keys())
class_names = [class_dict[class_id] for class_id in ordered_classes if class_id in class_distribution]
proportions = [class_distribution[class_id] for class_id in ordered_classes if class_id in class_distribution]
return class_names, proportions
def plot_class_distribution(ax, class_distribution, title, class_dict):
class_names, proportions = map_class_names_and_order(class_distribution, class_dict)
bars = ax.bar(class_names, proportions, color='skyblue', edgecolor='black')
for bar, proportion in zip(bars, proportions):
ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),
f"{proportion * 100:.2f}%", ha='center', va='bottom', fontsize=9)
ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.set_xlabel('Class')
ax.set_ylabel('Proportion of Pixels')
ax.set_title(title)
ax.tick_params(axis='x', rotation=45, labelsize=9)
ax.set_ylim(0, max(proportions)*1.1)
def analyze_class_distribution(dataset, dataset_name):
class_counts = Counter()
for idx in trange(len(dataset), desc=f"Analyzing {dataset_name}"):
try:
_, mask, _ = dataset[idx]
mask_array = np.array(mask)
unique, counts = np.unique(mask_array, return_counts=True)
class_counts.update(dict(zip(unique, counts)))
except Exception as e:
print(f"Error processing index {idx}: {e}")
continue
total_pixels = sum(class_counts.values())
class_distribution = {cls: count / total_pixels for cls, count in class_counts.items()}
return class_counts, class_distribution
train_class_counts, train_class_distribution = analyze_class_distribution(transformed_train_dataset, dataset_name="Train")
val_class_counts, val_class_distribution = analyze_class_distribution(transformed_val_dataset, dataset_name="Validation")
test_class_counts, test_class_distribution = analyze_class_distribution(transformed_test_dataset, dataset_name="Test")
Analyzing Train: 100%|██████████| 2398/2398 [00:33<00:00, 72.34it/s]
Analyzing Validation: 100%|██████████| 685/685 [00:09<00:00, 73.39it/s]
Analyzing Test: 100%|██████████| 343/343 [00:04<00:00, 73.40it/s]
from data.resolved_names import class_dict
fig, axes = plt.subplots(3, 1, figsize=(6, 11))
plot_class_distribution(axes[0], train_class_distribution, "Train Class Distribution", class_dict)
plot_class_distribution(axes[1], val_class_distribution, "Validation Class Distribution", class_dict)
plot_class_distribution(axes[2], test_class_distribution, "Test Class Distribution", class_dict)
plt.tight_layout()
plt.show()
The output shows that the distributions are roughly equal between the partitions. Between classes, however, there is some imbalance. This will have to be addressed later by using a loss function that takes class weights into account.
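As a sketch of such a loss, inverse-frequency weights can be derived from the top-5 pixel counts computed earlier (illustrative only; the training script computes its weights with sklearn and they may differ):

```python
import torch
import torch.nn as nn

# Pixel counts of the five kept classes (road, building, sky,
# vegetation, car), taken from the top-5 analysis above.
pixel_counts = torch.tensor(
    [668329989., 452690852., 418227625., 384303219., 254085637.])

# Balanced inverse-frequency weighting: rarer classes get larger weights.
weights = pixel_counts.sum() / (len(pixel_counts) * pixel_counts)

# Ignore index 255 so remapped/unlabeled pixels do not contribute to the loss.
criterion = nn.CrossEntropyLoss(weight=weights, ignore_index=255)
```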
The preprocessed data is now ready to be used for training.
Defining Some Testing Samples
To later look at the actual segmentation masks the models produce, I define some samples from the test partition here to pass through the models.
from data.utils import unnormalize
from analyzer import mean, std
selected_indices = [47, 60, 125, 133, 240, 313]
def plot_images(dataset, selected_indices, figsize=(15, 5)):
samples = dataset[selected_indices]
fig, ax = plt.subplots(1, len(samples), figsize=figsize)
fig.suptitle(f"Selected {len(samples)} Testing Samples", fontsize=16, y=.88)
for i, (image, mask, scene) in enumerate(samples):
ax[i].imshow(unnormalize(image, mean, std).permute(1, 2, 0))
ax[i].set_title(f"Sample {selected_indices[i]}")
ax[i].text(5, 5, scene, fontsize=12, color='white',
bbox=dict(facecolor='black', alpha=0.7, pad=2),
va='top', ha='left')
ax[i].axis('off')
plt.tight_layout()
plt.show()
plot_images(transformed_test_dataset, selected_indices)
selected_samples = transformed_test_dataset[selected_indices]
Here I have selected 3 samples in city scenes and 3 samples in non-city scenes. I have intentionally selected some interesting/difficult samples, or at least a variety. Here are my reasons for each sample:
- 47: Shows a residential area with a lot of vegetation but also a lot of shadow. I want to see how the models handle shadows that show the shape of an object but not its correct color.
- 60: Shows a packed city street in rainy weather. The sky is clearly overcast, so it lacks the usual blue tone, and the streets are wet, mirroring some shades.
- 125: A city scene at night; there is some illumination from the lights, but towards the upper end of the image it gets darker.
- 133: Shows a residential area in rather deep snow. Some of the cars are covered, as are the streets. I want to see how the models handle the snow.
- 240: Shows a darker, overcast city scene. The glass building on the right reflects the other side of the street with what looks like vegetation. I want to see whether the model classifies this as building or vegetation.
- 313: I chose this last sample as an "easy" one. It shows a relatively clear highway scene with just a few cars in the distance and a lot of vegetation and sky.
As already mentioned, these samples all stem from the test partition, so the models have never learned directly from them. This will give me good insight into how well the models generalize.
Training and Evaluation Skeleton
Karpathy: Set up the end-to-end training/evaluation skeleton + get dumb baselines
In the next step I establish the training and evaluation skeleton.
Trainer Class
To achieve this, I set up a Trainer class (in trainer.py) that houses training, validation and testing.
The Core Methods For Training
The core of the Trainer class is the run method, which is the entry point to training.
- Running the Training Loop (run):
  - If the model doesn't already exist:
    - Initializes model weights.
    - Iterates over epochs:
      - Trains for one epoch and validates on the validation set.
      - Logs metrics to wandb and the console.
      - Saves the model if the validation loss improves.
      - Checks for early stopping based on the validation loss and stops if the patience is exceeded.
  - Skips training if a model with the same name exists.
First, the run method starts training if the model does not already exist:
- Training a Single Epoch (_train_epoch):
  - Sets the model to training mode.
  - Resets the IoU metrics and iterates over the batches in the training loader:
    - Moves data to the appropriate device.
    - Clears gradients, computes outputs, calculates the loss, and updates the model weights.
    - Updates the IoU metrics, ignoring labels with value 255 (the "ignore" class).
  - Computes and returns the epoch-level training loss, global IoU, and per-class IoU.
After each epoch, the model is validated on the validation set, namely, when the run method calls _validate_epoch:
- Validating a Single Epoch (_validate_epoch):
  - Similar to _train_epoch but runs in evaluation mode with torch.no_grad():
    - Ensures no gradients are computed, to save memory and speed up computation.
  - Returns the validation loss, global IoU, and per-class IoU.
  - The model's weights are saved whenever a new lowest validation loss is reached.
  - If an early_stopping_patience is set, training stops if the validation loss does not improve for early_stopping_patience epochs.
After training and validation have concluded, the Trainer can be called to test the model on the test set:
- Testing the Model (`test`):
  - Sets the model to evaluation mode.
  - Iterates over the test data loader:
    - Processes batches without computing gradients.
    - Updates global and per-class IoU metrics, differentiating between "city" and "non-city" scenes.
  - Logs test metrics, including IoU for different scene types, to `wandb` and console.
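As a minimal sketch of the loop described above (hypothetical stand-ins, not the actual Trainer code), the checkpoint-and-early-stop logic of `run` boils down to:

```python
# Sketch of the run-loop: train, validate, checkpoint on the best validation
# loss, and stop early once patience runs out. train_epoch/validate_epoch are
# hypothetical stand-ins for the real Trainer methods.

def run(train_epoch, validate_epoch, max_epochs, patience):
    best_val_loss = float("inf")
    epochs_without_improvement = 0
    best_epoch = None
    for epoch in range(max_epochs):
        train_epoch(epoch)
        val_loss = validate_epoch(epoch)
        if val_loss < best_val_loss:        # new lowest val loss -> save weights
            best_val_loss = val_loss
            best_epoch = epoch              # stands in for writing a checkpoint
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:  # early stopping
                break
    return best_epoch, best_val_loss
```

This is why later plots always report test metrics from the lowest-validation-loss state: that is the checkpoint the loop keeps.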
from trainer import Trainer
Train Script
The training script, not to be confused with the Trainer class, is the global script that can be called by the user to train a model.
When a new training run is called, the script performs the following steps:
1. Initialize Configuration:
   - Uses Hydra to load configuration settings (`cfg`) from YAML files.
   - Logs the run name and full configuration for reproducibility.
2. Set Working Directory:
   - Ensures the script runs from the original working directory using Hydra utilities.
3. Dataset Preparation:
   - Loads the `BDD100KDataset` with the specified data transformations.
   - Optionally subsets the dataset for overfitting tests.
   - Splits the dataset into training, validation, and test sets based on specified ratios.
4. Data Loaders:
   - Creates PyTorch `DataLoader` objects for the train, validation, and test datasets.
5. Class Weight Calculation:
   - If a class-weights file exists, loads the weights.
   - Otherwise, calculates class weights using `sklearn` based on label frequencies in the training set.
   - Since the class weight calculation is computationally heavy, the computed weights are saved to disk for future runs.
6. Model Initialization:
   - Instantiates the model and optimizer from the configuration.
   - Logs the model architecture for inspection.
7. Criterion Selection:
   - Configures the loss function:
     - Uses `CrossEntropyLoss` or `FocalLoss` based on the configuration.
     - Includes the calculated class weights and an ignore index (`255`) for "ignore" labels.
8. Trainer Setup:
   - Initializes the `Trainer` class with the model, criterion, optimizer, and training settings.
9. Run Training and Testing:
   - Calls the `trainer.run` method to train and validate the model.
   - Evaluates the model on the test set using `trainer.test`.
10. Script Entry Point:
    - Ensures the `main` function is executed when the script is run directly.
import train
When training a model on the i4DS Slurm Cluster, the train.py script is not the first interface called by the user. Instead, the scripts/ directory provides a submit_job.sh script that initializes the environment and calls the train.py script with the appropriate arguments plus additional cluster resource allocation arguments.
Hydra Configuration
To orchestrate all the different following experiments and models (training configurations), I have added a Hydra setup to this project.
The configuration files are located in configs/ where each configurable part of the training process is defined in a separate YAML file. The experiments-directory houses all the different run configurations for this mini challenge. Whenever train.py experiment=[experiment_name] is called, the train.py script will inject the specified experiment configuration and load all specified parameters into each corresponding component.
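The override mechanic behind `train.py experiment=[experiment_name]` can be illustrated with plain dicts standing in for the YAML files (a toy sketch of config composition, not Hydra's actual API):

```python
# Toy illustration of config composition: an experiment config is merged on
# top of the defaults, so only the overridden keys change. Hydra performs this
# composition from the YAML files in configs/; plain dicts stand in here.

def compose(defaults, experiment):
    merged = dict(defaults)
    merged.update(experiment)       # experiment keys win over defaults
    return merged

defaults = {"lr": 1e-4, "weight_decay": 0.0, "loss": "cross_entropy"}
experiment = {"lr": 1e-3, "weight_decay": 1e-4}   # hypothetical experiment file

cfg = compose(defaults, experiment)
# cfg keeps the default loss but takes the experiment's lr and weight_decay
```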
Loss Functions
Weighted Cross Entropy
As a default loss I have used PyTorch's CrossEntropyLoss with class weights as already mentioned in the training script section. The class weights are calculated based on the frequency of each class in the training set, to give more importance to underrepresented classes.
Weighted Cross Entropy is mathematically defined as $$L=-\frac{1}{N}\sum_{i=1}^N\sum_{c=1}^C w_c\cdot y_{i,c}\log(p_{i,c})$$ So it is the same as regular Cross Entropy, except a weight $w$ of class $c$ is multiplied to the entropy of that class.
A not-so-apparent nuance here is that the weights for this loss function have to be the inverse frequency; they can't be just the frequency of each class. Since the loss is minimized, underrepresented classes need larger weights to contribute more, so each class is weighted by the reciprocal of its frequency: $$w_c=\frac{1}{f_c}$$ or in the normalized form $$w_c=\frac{\frac{1}{f_c}}{\sum_{j=1}^C\frac{1}{f_j}}$$
In my implementation, the class weights get calculated by sklearn's compute_class_weight-function, which automatically calculates the normalized inverse frequency.
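The normalized inverse-frequency formula above can be spelled out in a few lines of numpy (a sketch of the math, not the sklearn call the script actually uses):

```python
import numpy as np

# Normalized inverse-frequency class weights, per the formula above.
# `labels` is a flattened array of per-pixel class ids; in the real script
# these come from the training masks.

def inverse_frequency_weights(labels, num_classes):
    counts = np.bincount(labels, minlength=num_classes).astype(float)
    freqs = counts / counts.sum()          # f_c
    inv = 1.0 / freqs                      # 1 / f_c
    return inv / inv.sum()                 # normalize so the weights sum to 1

labels = np.array([0, 0, 0, 1])            # class 0 is 3x more frequent
w = inverse_frequency_weights(labels, num_classes=2)
# the rarer class 1 receives three times the weight of class 0
```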
Focal Loss
I have also implemented a FocalLoss class as an alternative loss function according to Lin et al. (2017). The Focal Loss builds upon the Weighted Cross Entropy Loss by introducing a modulating factor that dynamically adjusts the loss contribution of each sample based on the model's confidence. This approach is particularly effective in addressing class imbalance and focusing the training process on harder, misclassified examples.
Mathematically, the Focal Loss is defined as:
$$ L = -\frac{1}{N} \sum_{i=1}^N \sum_{c=1}^C w_c \cdot (1 - p_{i,c})^\gamma \cdot y_{i,c} \log(p_{i,c}) $$
where:
- $N$ is the number of samples.
- $C$ is the number of classes.
- $w_c$ is the weight for class $c$, set to the inverse frequency to balance class importance.
- $p_{i,c}$ is the predicted probability for class $c$ of sample $i$.
- $y_{i,c}$ is the ground truth indicator (1 if sample $i$ belongs to class $c$, else $0$).
- $\gamma$ is the focusing parameter that determines the strength of the modulation.
In my implementation, the FocalLoss class allows for tuning the focusing parameter $\gamma$ to balance the trade-off between focusing on hard examples and maintaining stable training.
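The formula can be checked numerically with a small numpy sketch (the project's FocalLoss is a PyTorch module; this just spells out the math on toy inputs):

```python
import numpy as np

# Numpy sketch of the focal loss formula above. `probs` holds predicted class
# probabilities of shape (N, C), `targets` integer labels, `weights` the
# per-class w_c, and gamma is the focusing parameter.

def focal_loss(probs, targets, weights, gamma=2.0):
    n = len(targets)
    p_t = probs[np.arange(n), targets]     # p_{i,c} of the true class
    w_t = weights[targets]                 # w_c of the true class
    return -np.mean(w_t * (1.0 - p_t) ** gamma * np.log(p_t))

probs = np.array([[0.9, 0.1],              # confident, correct sample
                  [0.6, 0.4]])             # uncertain sample
targets = np.array([0, 1])
weights = np.array([1.0, 1.0])

# With gamma=0 this reduces to (weighted) cross entropy; larger gamma
# down-weights the already-confident first sample relative to the second.
ce_like = focal_loss(probs, targets, weights, gamma=0.0)
focal = focal_loss(probs, targets, weights, gamma=2.0)
```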
from core import FocalLoss
(Mean) IoU Metric
To evaluate each trained model, I decided to focus on the Intersection over Union (IoU) metric. It is a common choice for semantic segmentation tasks as it provides a more informative measure of model performance than accuracy. It quantifies the overlap between the predicted segmentation and the ground truth segmentation for a given class: $$IoU=\frac{Intersection}{Union}$$ where:
- Intersection is the area of overlap between the predicted segmentation and the ground truth segmentation for that class.
- Union is the total area covered by the predicted segmentation and the ground truth mask for that class.
In my Trainer class, the IoU is calculated once per class as defined above, but also as a "global" IoU, which is just the mean IoU over all classes: $$mIoU=\frac{1}{C}\sum_{i=1}^C IoU_i$$
In the Trainer itself, the IoU is calculated using the JaccardIndex. The Jaccard Index is simply another name for the IoU, more commonly used in the context of set theory.
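Per-class IoU and the mean IoU can be sketched directly for integer masks (a minimal numpy version mirroring the definition the Trainer's JaccardIndex implements):

```python
import numpy as np

# Per-class IoU on flattened integer masks, with the same ignore_index=255
# convention used during training, plus the mean IoU over valid classes.

def iou_per_class(pred, target, num_classes, ignore_index=255):
    keep = target != ignore_index          # drop "ignore" pixels
    pred, target = pred[keep], target[keep]
    ious = []
    for c in range(num_classes):
        inter = np.logical_and(pred == c, target == c).sum()
        union = np.logical_or(pred == c, target == c).sum()
        ious.append(inter / union if union > 0 else float("nan"))
    return ious

pred   = np.array([0, 0, 1, 1])
target = np.array([0, 1, 1, 1])
ious = iou_per_class(pred, target, num_classes=2)
# class 0: intersection 1, union 2 -> 0.5; class 1: intersection 2, union 3
miou = np.nanmean(ious)
```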
from analyzer import Analyzer
analyzer = Analyzer()
Setting up a Baseline & Validating it
For a baseline I chose to construct a simple U-Net that I can train on a single batch. This will give me a good idea of what the training process looks like and whether the proposed Trainer class and train.py script work.
This experiment was configured according to configs/experiment/unet_baseline.yaml, so it has 3 encoding and 3 decoding layers which project down into a bottleneck and then back up to the original image size. The model is trained on a single batch of 8 images. The first encoding layer has 64 filters, then 128, then 256 and the decoding layer vice-versa. This simple model is trained for 50 epochs without early stopping. It otherwise inherits the default configs for each component in the Hydra setup.
To train on a single batch of batch_size=8 I cut down the dataset to 12 images using the overfit_test: true setting in the configuration file. This assigns 70% of the dataset to the training set, 20% to the validation set and 10% to the test set. The validation and test partition in this case have only 2 images each.
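The partition sizes can be sanity-checked with a tiny sketch (the exact rounding inside the real split_dataset helper is an assumption here):

```python
# Sketch of the 70/20/10 split sizes for the 12-image overfit subset.
# Assumption: train and validation sizes are floored, the remainder goes
# to the test partition.

def split_sizes(n, train_ratio=0.7, val_ratio=0.2):
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)
    return n_train, n_val, n - n_train - n_val

# For the 12-image subset this yields 8 training, 2 validation and 2 test
# images, matching the partition sizes described above.
sizes = split_sizes(12)
```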
Note: From now on, the next plots will show test metrics from the model state at its lowest validation loss. The Trainer always saves the model weights at its lowest validation loss, as described in the Trainer section. This is to ensure that the model is evaluated at its best state.
from core import UNet
baseline_analyzer = Analyzer(UNet(num_classes=5,
encoder_dims=[64, 128, 256],
decoder_dims=[256, 128, 64]))
baseline_analyzer.plot('1b1fhr6l')
The results show quite an interesting behavior:
- The training loss decreases as expected; the validation loss, however, does not until around epoch 20. The model seems to converge at that point and to learn some general gist that also benefits the few validation samples.
- The validation IoU curves show a similar behavior to the loss. The global IoU starts to increase more clearly and steadily after epoch 30, as does every class besides the road class, which seems to be the most difficult in the two validation samples.
- The test metrics, however, are really imbalanced.
Let's take a closer look at the samples we have passed into this first test to perhaps understand why the road class is so difficult to classify. First I partition the 12 samples with the same seed just like in the train.py-script to make sure we are looking at the same samples as the model. Then I plot the training subset:
from data.utils import split_dataset
baseline_subset = dataset[:12]
train_subset, val_subset, test_subset = split_dataset(baseline_subset, random_seed=1337)
plot_images_and_masks(train_subset, class_colors=class_colors_hex, figsize=(8, 25))
There are two conspicuous samples, namely on rows 4 and 5. The sample on row 4 is a sample we have already inspected during the exploration step, where we noted that the road is falsely annotated by the dataset, giving it a label of sidewalk. The sample on row 5 has a correctly labeled road part but it is extremely small. I suspect these two samples to be the main reason the road class is so difficult to classify.
Perhaps the validation samples are also unluckily picked, so let's look at those next:
plot_images_and_masks(val_subset, class_colors=class_colors_hex, figsize=(10, 7))
Here we can also see that in the second sample the road class is falsely annotated. It has the label sky instead of road.
Looking at the samples showed to be a good idea since the problems became apparent and seem to not stem from my training process or model architecture.
Overfit
Karpathy: Overfit
In the overfit stage I will now experiment with a more complex U-Net architecture (vanilla_unet.py).
The Vanilla U-Net has 4 encoding and 4 decoding layers. The model class has a parameter base_filters which determines the complexity of the model. The base_filters determine the number of filters in the first encoding layer's DoubleConv block. At each next encoding layer the previous number of filters gets doubled. The decoding layers then have the same number of filters as the corresponding encoding layer.
In this step I will first give the Vanilla U-Net a low number of filters, then gradually increase it and look at the performance of the model to determine where the best generalization regime may lie. Contrary to the overfit testing in the previous step I will now use the full dataset.
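The doubling scheme described above can be spelled out (a sketch of the widths the text implies; the real VanillaUNet builds these internally):

```python
# Encoder widths implied by base_filters: each of the four encoder levels
# doubles the previous width, ending in the bottleneck (the fifth entry).

def encoder_dims(base_filters, depth=4):
    return [base_filters * 2**i for i in range(depth + 1)]  # incl. bottleneck

# base_filters=32 -> [32, 64, 128, 256, 512]: four encoder levels and a
# 512-filter bottleneck, matching the configuration in Experiment 1 below.
dims = encoder_dims(32)
```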
from core import VanillaUNet
Experiment 1: Vanilla U-Net with base_filters=32
The Vanilla U-Net first gets trained with an architecture that starts at 32 filters in the first encoding layer, up to 512 filters in the bottleneck.
I expect the model to overfit and not regularize well since it most likely will be too small to capture the complexity of the dataset.
overfit_1_analyzer = Analyzer(VanillaUNet(num_classes=5, base_filters=32))
overfit_1_analyzer.plot('wlkbib58')
The training loss decreases relatively smoothly but already converges after ~30 epochs. The validation loss thus does not show a continual decrease but rather stagnates quickly before moving into an overfit regime.
Experiment 2: Vanilla U-Net with base_filters=64
Next I will increase the number of base filters to 64 which increases the dimensionality of all encoding and decoding layers up into a latent bottleneck of 1024 filters.
overfit_2_analyzer = Analyzer(VanillaUNet(num_classes=5, base_filters=64))
overfit_2_analyzer.plot('7knyl98e')
The training loss now shows its capability to decrease further and not show clear convergence yet. The test IoU metrics also increased by a lot. The largest improvement can be seen in the car label - This could be thanks to the increased model variance (complexity) since that is most likely the hardest class to capture as it can occur in many different shapes, sizes/distances and at any spatial location.
Experiment 3: Vanilla U-Net with base_filters=128
Lastly I will increase the number of base filters to 128 which increases the dimensionality of all encoding and decoding layers up into a latent bottleneck of 2048 filters. We have already seen quite a significant improvement from experiment 1 to 2 and the training loss was not clearly converging yet. Just to validate the observations from experiment 2 I will now increase the model complexity one step further.
overfit_3_analyzer = Analyzer(VanillaUNet(num_classes=5, base_filters=128))
overfit_3_analyzer.plot('g1mw1zpz')
Starting from 128 filters, the model improves even further. The validation loss seems to reach a lower plateau than in experiments 1 and 2 and therefore also reaches higher IoU scores on the held-out test set.
Summarizing the Overfitting Experiments
To summarize the experiments here I compare all the runs:
analyzer.compare_runs(['wlkbib58', '7knyl98e', 'g1mw1zpz'])
I have already established most of the observations that reappear in the comparison plot above. The biggest improvement happens between Experiments 1 and 2. The model seems to benefit from larger complexity, but the improvement starts to stagnate between Experiments 2 and 3. An interesting observation is that the road and sky classes always outperform the other labels, even staying the same between Experiments 1 and 2. I assume this is because these classes are the most frequent and roughly always occur in the same spatial location, while the other classes have a higher variance in their spatial distribution (as seen in the exploration step when looking at the class heatmaps).
If we look at the performance in regards to answering the research question, we can see that the model seems to be slightly more capable in City environments, though the difference does not look significant.
To see the "real world performance", I will now also look at some segmentation samples produced by my "winner model" in this stage:
overfit_3_analyzer.sample('g1mw1zpz', selected_samples)
Searching for model weights for run g1mw1zpz... Found model: unet_overfit_3_g1mw1zpz.pth!
The model generally seems to work pretty well. In most of the difficult scenes it got the most crucial parts (arguably road and car) correct. It seems to work well within shadowy scenes and wet streets (although the ground truth here did not correctly classify the road, but the predicted mask undoubtedly got most of the road correct). Also the reflection in the 5th sample was correctly classified as building and not vegetation.
The ground truth has some parts that are unlabeled; the model, however, was designed to classify everything as one of the top 5 classes, so naturally some parts, like the POV car's cockpit and hood, get classified as something.
One of the clearer struggles however seems to be the snow. The model classified the narrow plowed path as road but the snow on the sides, where cars are parked, got sometimes also classified as car.
Regularization
Karpathy: Regularize
In this section I now regularize the training process by adding weight_decay to the Adam optimizer. It adds an $L_2$ penalty term to the loss function as follows: $$L_\text{total}=L+\frac{\lambda}{2}||w||_2^2$$
Here, $L$ is just the Weighted Cross Entropy Loss or Focal Loss in my case and $w$ are the weights. $\lambda$ is a tunable hyperparameter that determines the strength of regularization, so this is essentially the weight_decay parameter that gets changed in the following experiments.
Weight decay in gradient-based optimization can actually be implemented without explicitly adding the penalty to the loss; it can be directly incorporated into the weight update rule: $$w\leftarrow w-\eta\nabla L$$ then becomes $$w\leftarrow w-\eta\nabla L-\eta\lambda w$$ with the penalty term $\lambda w$ introducing the decay.
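For plain gradient descent the two formulations coincide, which a tiny numeric check makes concrete (toy numbers, not project code; note that for adaptive optimizers like Adam the two formulations differ, which is what AdamW addresses):

```python
import numpy as np

# Numeric check of the equivalence above: adding (lambda/2)*||w||^2 to the
# loss contributes lambda*w to the gradient, which is exactly the decay term
# in the modified update rule.

w = np.array([1.0, -2.0])
grad_L = np.array([0.5, 0.5])        # gradient of the unregularized loss
eta, lam = 0.1, 0.01                 # learning rate and weight_decay

w_penalty = w - eta * (grad_L + lam * w)       # L2 penalty folded into gradient
w_decay = w - eta * grad_L - eta * lam * w     # decay applied directly
# both updates land on the same weights
```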
Experiment 1: Vanilla U-Net with weight_decay=0.001
I am first adding a smaller weight decay to the model to see how it affects the training process.
regularize_analyzer = Analyzer(VanillaUNet(num_classes=5, base_filters=128))
regularize_analyzer.plot('wqubevv6')
The addition of weight decay seems to be quite beneficial: the validation loss decreases more steadily. However, the IoU metrics on the held-out test set are not really improving.
Experiment 2: Vanilla U-Net with weight_decay=0.01
Then I am adding a 10-fold larger weight decay of 0.01. The previous experiment still showed some tendencies of overfitting towards the end of the training process so a larger weight decay might help to regularize the model more.
regularize_analyzer.plot('j5p25e3z')
The loss curves are still decreasing however the validation loss now seems to have gotten a bit more noisy and the test IoU metrics seem to have gotten worse since the training now converges earlier. The noisier loss curve could be due to the fact that I did not reduce the learning rate simultaneously when increasing regularization.
regularize_analyzer.compare_runs(['wqubevv6', 'j5p25e3z'])
The lower weight decay of 0.001 seems to be the better choice; the weight decay of 0.01 seems to regularize too harshly. However, it can be said that regularization with weight decay may not result in the desired outcome: we lose some generalization capability on the test set compared to the overfitting experiments, which reached a lower validation loss.
To also look at the "best model" in this experiment I will plot the selected samples as predicted by the model that was regularized with weight_decay=0.001:
regularize_analyzer.sample('wqubevv6', selected_samples)
Searching for model weights for run wqubevv6... Found model: unet_regularize_wd_0.001_wqubevv6.pth!
In general the samples also show fairly good capability of the model when it comes to road, car and vegetation. It however does clearly struggle at some points, at least more than the overfit experiment winner. In the shadowy scene the model has classified the right part of the sidewalk as car. It also has difficulties with the reflection in the right glass building of the 5th row's sample where it thought that the reflection is actually vegetation and not building.
For further training I will now set weight decay to an even lower value of 0.0001 to see if the model can generalize better.
Tuning the model
Karpathy: Tune
In this section I have decided to explore a range of configurations in a grid with different learning rates and both proposed loss functions.
tuning_grid = {
'3jqyweld': [0.00001, 'Cross Entropy'],
'powrgvfz': [0.0005, 'Cross Entropy'],
'10qppijg': [0.0001, 'Cross Entropy'],
'cu3the78': [0.001, 'Cross Entropy'],
'w8r0kyr4': [0.00001, 'Focal Loss'],
'w83bx002': [0.0005, 'Focal Loss'],
'7u2ywdlf': [0.0001, 'Focal Loss'],
'6imza8mb': [0.001, 'Focal Loss']
}
analyzer = Analyzer(device="cpu", project_name="dlbs", entity_name="okaynils")
analyzer.plot_grid_results(tuning_grid)
The winner of the grid search seems clearly to be the vanilla U-Net with Cross Entropy and a higher learning rate. Focal Loss might be too nuanced since the dataset isn't actually severely imbalanced so a simpler cost function does the job well, reaching a global IoU of 80% on the held out test partition.
Crowning a Winner
The final hyperparameter tuning, together with previous experiments results in the best model being the Vanilla U-Net with base_filters=128, weight_decay=0.0001, learning_rate=0.001 and CrossEntropyLoss.
Here I will now inspect this winner more closely.
winner_model_analyzer = Analyzer(VanillaUNet(num_classes=5, base_filters=128))
winner_model_analyzer.plot('cu3the78')
This run performed the best without showing any clear signs of overfitting in the loss curves yet. It steadily decreased throughout the epochs, reaching the lowest validation loss at epoch 29 before the model started to overfit and the training terminated due to early stopping.
The class specific IoUs also show a steady improvement but the general observation of road and sky outperforming the other classes seems to be consistent throughout all experiments.
winner_model_analyzer.sample('cu3the78', selected_samples)
Searching for model weights for run cu3the78... Found model: unet_ce_lr_0.001_cu3the78.pth!
The samples also validate the numerical evaluation results. The model again struggles with snow and with cluttered, highly structured parts of scenes, where it resorts to classifying car.
Squeeze The Juice
Karpathy: Squeeze the juice
Implementing Attention into the U-Net
Another route I wanted to explore is the implementation of the attention mechanism into the U-Net.
In the realm of semantic segmentation, particularly in complex and dynamic environments such as those encountered in autonomous driving, effectively capturing and utilizing contextual information is really important and in some cases even life-saving. My implementation of the Attention U-Net architecture integrates an attention mechanism into the traditional U-Net framework, with which I hope to improve its capability to focus on relevant features while suppressing irrelevant ones.
from core import AttentionUNet
The AttentionUNet encapsulates the AttentionBlock class which serves as a gating unit that modulates the feature maps from the encoder before they are concatenated with the decoder's upsampled features. With this selective gating I hope to ensure that the decoder only receives the most pertinent spatial information.
Functionality
There are three main components in the attention mechanism that I have implemented:
1. Projection Layers
   - Gating Signal Transformation (`W_g`): The gating signal `g`, originating from the decoder's deeper layers, is first projected to an intermediate feature space of dimension `F_int` using a `1x1` convolution followed by batch normalization. This transformation aligns the gating signal's dimensions with those of the skip connection: $$W_g(g)=BatchNorm(Conv(g))$$
   - Skip Connection Transformation (`W_x`): Similarly, the skip connection `x` from the encoder is projected to the same intermediate feature space using another `1x1` convolution and batch normalization: $$W_x(x)=BatchNorm(Conv(x))$$
2. Attention Coefficient Computation (`psi` or $\psi$)
   - Element-wise Summation and Activation: The transformed gating signal and skip connection are summed element-wise, followed by a `ReLU` activation to introduce non-linearity: $$\psi=ReLU(W_g(g)+W_x(x))$$
   - Sigmoid Activation ($\sigma$): A final `1x1` convolution and batch normalization are applied, followed by a sigmoid activation to generate the attention coefficients. These coefficients, ranging between 0 and 1, act as a spatial mask that highlights salient features (which I later look at more closely): $$\psi=\sigma(BatchNorm(Conv(\psi)))$$
3. Feature Modulation
   - The original skip connection `x` is multiplied element-wise by the attention coefficients `psi` ($\psi$). This effectively scales the features based on their relevance to the current decoding stage: $$x_\text{attended}=x\times\psi$$
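To make the gating concrete, here is a toy numpy walk-through of the three steps on a single-channel feature map (scalars stand in for the `1x1` convolutions and batch norm is omitted; this is a sketch, not the AttentionBlock code):

```python
import numpy as np

# Toy walk-through of the attention gate on a 2x2 single-channel feature map.
# W_g, W_x, W_psi are hypothetical scalar stand-ins for 1x1-conv weights.

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

x = np.array([[1.0, 0.0], [0.5, 2.0]])   # encoder skip connection
g = np.array([[0.2, 0.1], [0.9, 0.8]])   # decoder gating signal
W_g, W_x, W_psi = 1.0, 1.0, 1.0

summed = np.maximum(W_g * g + W_x * x, 0.0)   # ReLU(W_g g + W_x x)
psi = sigmoid(W_psi * summed)                 # attention coefficients in (0, 1)
x_attended = x * psi                          # feature modulation

# psi is highest where both the skip features and the gating signal are
# strong (bottom-right), so that location is attenuated the least.
```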
Training the Attention U-Net
First I am training a simple Attention U-Net; this time I went with a base_filters=64 configuration since the attention mechanism itself adds complexity (more parameters) to the model.
attn_unet_analyzer = Analyzer(AttentionUNet(num_classes=5, base_filters=64))
attn_unet_analyzer.plot('81ytcmrq')
Introducing the attention mechanism also means adding more variance, or complexity, to the model. This naturally invites overfitting, which may be the reason the validation loss does not decrease as much as one might expect.
I will explore the Attention mechanism further by applying Dropout layers. Previously I have only regularized the models directly in the loss function using the $L_2$-Norm (or weight decay), this time I hope to achieve more with Dropout layers. Dropout is generally insensitive to the input and output scaling so I hope to see a different result this time that allows me to regularize the model more effectively.
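A quick sketch of inverted dropout (the standard torch-style implementation; a sketch, not the project's layers) shows why training-time losses read inflated while evaluation is unaffected:

```python
import numpy as np

# Inverted dropout: at training time activations are zeroed with probability p
# and the survivors are scaled by 1/(1-p), so the expected activation matches
# evaluation mode, where dropout is a no-op.

rng = np.random.default_rng(0)

def dropout(a, p, training=True):
    if not training or p == 0.0:
        return a                       # eval mode: identity
    mask = rng.random(a.shape) >= p    # keep each unit with probability 1-p
    return a * mask / (1.0 - p)        # rescale survivors

a = np.ones(100_000)
dropped = dropout(a, p=0.2)
# the mean stays close to 1.0 in expectation, while individual activations
# are either zeroed or amplified -- the source of the noisier training loss
```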
Additionally, to really "squeeze the juice out of it" I am training the Attention U-Net with a maximum of 100 epochs but still with early stopping applied. This gives the runs the option to further explore the training process and potentially reach a lower validation loss.
Experiment 1: Attention U-Net with dropout_prob=0.1
First I am applying a Dropout probability of 0.1 (10%). This is a relatively low dropout probability but I want to start low to see how the model reacts to it.
attn_unet_dropout_1_analyzer = Analyzer(AttentionUNet(num_classes=5, base_filters=64, dropout_prob=0.1))
attn_unet_dropout_1_analyzer.plot('bta62w43')
The validation loss now follows the training loss more closely, though, it must be noted that the training loss is inflated due to the dropout layers. Seeing a validation loss lower than the training loss may raise concerns but this is a natural effect of dropout. When looking at the IoU scores on the test set we can see that the model is clearly improving, the classification of building, road, sky and car is significantly higher this time. Only vegetation did not improve as much (only by about 2%).
Experiment 2: Attention U-Net with dropout_prob=0.2
Next, I will look at a Dropout probability of 0.2 (20%). The previous experiment showed that towards the end the validation loss started to increase again, resulting in early stopping so it makes sense to apply stronger regularization.
attn_unet_dropout_2_analyzer = Analyzer(AttentionUNet(num_classes=5, base_filters=64, dropout_prob=0.2))
attn_unet_dropout_2_analyzer.plot('s709aw35')
The validation loss now starts to show slow convergence and stabilization towards the end of the training process. Each class' IoU also starts to look less noisy and improves in a stable way. The test IoU metrics also get even better!
Experiment 3: Attention U-Net with dropout_prob=0.3
To see if I have reached a ceiling on regularization I will now apply a Dropout probability of 0.3.
attn_unet_dropout_3_analyzer = Analyzer(AttentionUNet(num_classes=5, base_filters=64, dropout_prob=0.3))
attn_unet_dropout_3_analyzer.plot('zhzkaiyx')
The loss curves now start to look like they have not started to converge yet. Comparing the IoU scores to the previous experiment we can see that we are also starting to lose some generalization capabilities since a Dropout probability of 30% may be too harsh.
analyzer.compare_runs(['81ytcmrq', 'bta62w43', 's709aw35', 'zhzkaiyx'])
We can clearly observe that dropout regularization has yielded positive results for the Attention U-Net. I was able to increase the global IoU on the test set by almost 10% compared to the Attention U-Net without regularization. The model trained with dropout_prob=0.2 seems to have performed the best, both for the individual classes' IoU scores and for the IoU scores on city and non-city scenes.
I will therefore take a closer look at that model's performance.
Taking A Look At Saliency Attention Maps
Last but not least I am interested in how the attention mechanism actually works on the data. For that I have added a plot_attention_maps-function to the Analyzer class.
What this function does is:
1. Model Check:
   - The method first checks if the model is an instance of `AttentionUNet`. If not, it exits with a message.
2. Initialize Attention Maps:
   - It initializes an empty list, `self._attention_maps`, where the attention maps will be stored.
3. Register Attention Hooks:
   - Calls `_register_attention_hooks` to attach forward hooks to each `AttentionBlock` in the model.
   - The hook function `_attention_hook` is defined to:
     - Extract input (`x`) and gating (`g`) tensors.
     - Apply the weights (`W_g`, `W_x`) to `g` and `x`.
     - Perform a `relu` and apply another transformation (`psi`) on their sum.
     - Append the computed attention map `psi` to `self._attention_maps`.
   - The hooks are stored in `self._attention_hooks` for future removal.
4. Model Evaluation:
   - Sets the model in evaluation mode (`self.model.eval()`), disabling dropout and batch norm updates.
   - The input `image_tensor` is moved to the appropriate device (e.g., GPU).
   - The model processes the input in `with torch.no_grad()`, ensuring no gradients are computed.
5. Plot Attention Maps:
   - After the model processes the input, `_plot_collected_attention_maps` is called to visualize the attention maps.
   - Each map is upsampled to match the input size and displayed using `matplotlib`.
6. Remove Attention Hooks:
   - Calls `_detach_attention_hooks` to remove the forward hooks and clear references to prevent memory leaks.
   - The `analyzer` attribute is removed from each `AttentionBlock`.
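The register/collect/remove cycle above can be illustrated without PyTorch; this minimal stand-in mimics the pattern of `module.register_forward_hook` (all class and variable names here are hypothetical):

```python
# Minimal stand-in for forward-hook mechanics: a hook is a callback that fires
# after a block computes its output, letting an observer collect intermediate
# results (here, stand-in attention maps) without changing the block itself.

class Block:
    def __init__(self, name):
        self.name = name
        self._hooks = []

    def register_forward_hook(self, fn):
        self._hooks.append(fn)
        return fn                          # handle used later for removal

    def remove_hook(self, fn):
        self._hooks.remove(fn)

    def forward(self, x):
        out = x * 2                        # placeholder computation
        for hook in self._hooks:
            hook(self, x, out)             # (module, input, output) signature
        return out

attention_maps = []
blocks = [Block("attn1"), Block("attn2")]
handles = [b.register_forward_hook(
               lambda module, inp, out: attention_maps.append((module.name, out)))
           for b in blocks]

for b in blocks:
    b.forward(1)                           # attention_maps gains one entry each

for b, h in zip(blocks, handles):          # detach hooks to avoid leaks
    b.remove_hook(h)
```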
As mentioned, I will now primarily look at the "winner model" of the Attention experiments, the Attention U-Net trained with a dropout rate of 20%.
winner_attention_model = AttentionUNet(num_classes=5, base_filters=64, dropout_prob=0.2).to("cuda")
winner_attention_model_analyzer = Analyzer(model=winner_attention_model, device="cuda")
winner_attention_model_analyzer._load_model_weights('s709aw35')
winner_attention_model_analyzer.plot_attention_maps(selected_samples, figsize=(15, 3))
Searching for model weights for run s709aw35... Found model: unet_attn_dropout_0.2_s709aw35.pth!
As a result we can see the $\psi$-Coefficients of each attention layer for our selected samples.
The first attention layer seems to focus on the general location of the objects; it especially attends to the sides of the streets. This could be because the sides are often the biggest discriminator for what will be in a scene. For example, the second sample contains a lot of different objects, and the attention layer attends to all the objects on the sides, whereas in the last sample the attention layer does not have many objects that are "hard to classify", as the scene is mostly just road and sky. In this last sample, though, there is an interesting activation at the top left of the image. I can't really make out why that is, but perhaps it picks up on the image having clear skies at that point.
The second layer goes more into detail and focuses on clearer shapes, a pattern that can be observed in the following attention maps as well. In attention map 4 we can see the mechanism has clear shapes to attend to. The $\psi$ coefficient seems to be especially high for parts of the image with a lot of structure, as that is where it can gather a lot of information for labeling.
winner_attention_model_analyzer.sample('s709aw35', selected_samples)
Searching for model weights for run s709aw35... Found model: unet_attn_dropout_0.2_s709aw35.pth!
The samples of this "winner" attention model show, in my opinion, the best performance yet. In complex, packed scenes it classifies most pixels correctly, even for small shapes. The most impressive part for me is the first sample, which even manages to get the POV car correct.
Overall it can be said that the dataset at hand poses many difficulties that make modelling really hard. Reducing the number of classes improved the task's performance by a meaningful amount, but the numerous falsely labeled objects make it hard for the simpler models to catch all nuances. The Attention U-Net performed surprisingly well, even though we don't have too much data at hand; attention mechanisms usually are really data hungry. The dropout regularization boosted the attention mechanism further, and the model was able to generalize better.
Looking at the bare numbers, one may say that the Attention models performed slightly worse, but after looking at the selected samples I am convinced that the attention mechanism really helped the model focus on the right parts of the image and get better at complex scenes.
Comparing All Winners
analyzer.compare_runs(['g1mw1zpz', 'cu3the78', 'wqubevv6', 's709aw35'])
The regular Vanilla U-Net performed the best in terms of global IoU, for both the city and non-city scenes as well as the individual classes. But compared to the Attention U-Net it made worse predictions on the observed selected samples. The Attention U-Net was able to capture more complex scenes and had clearer, more confident labels for objects where other models assigned multiple classes to a single object.
So, to answer the research question: the Vanilla U-Net with 128 base filters and no regularization achieved the best metrics, but the Attention U-Net, which after visual evaluation appears to generalize better, may be the better choice for the task.
Outlook
Here are some ideas for further improvements:
- Augmentation of the dataset to increase the number of samples to train on
- Improve the dataset's ground truths so that fewer objects are falsely labeled, by filtering out false annotations or using a third-party annotator